{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "# 自然语言推理和数据集\n",
    "\n",
    "斯坦福大学自然语言推理(SNLI)语料库"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "origin_pos": 2,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import re\n",
    "import torch\n",
    "from torch import nn\n",
    "from d2l import torch as d2l\n",
    "\n",
    "# Register the SNLI archive under the 'SNLI' key of d2l's data hub. The\n",
    "# entry pairs the download URL with a 40-hex-digit checksum string\n",
    "# (presumably the archive's SHA-1 -- confirm against d2l).\n",
    "snli_url = 'https://nlp.stanford.edu/projects/snli/snli_1.0.zip'\n",
    "snli_checksum = '9fcde07509c7e87ec61c640c1b2753d9041758e4'\n",
    "d2l.DATA_HUB['SNLI'] = (snli_url, snli_checksum)\n",
    "\n",
    "# The value returned here is what later cells pass to read_snli as its\n",
    "# data directory.\n",
    "data_dir = d2l.download_extract('SNLI')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "读取数据集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "origin_pos": 4,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "def read_snli(data_dir, is_train):\n",
    "    \"\"\"Read the SNLI dataset into premises, hypotheses, and labels.\n",
    "\n",
    "    Args:\n",
    "        data_dir: Directory containing `snli_1.0_train.txt` and\n",
    "            `snli_1.0_test.txt`.\n",
    "        is_train: Truthy to read the training file, falsy for the test file.\n",
    "\n",
    "    Returns:\n",
    "        A tuple `(premises, hypotheses, labels)` of three parallel lists.\n",
    "        Labels are integers: entailment=0, contradiction=1, neutral=2.\n",
    "        Rows whose first column is not one of those three names are\n",
    "        skipped.\n",
    "\n",
    "    Raises:\n",
    "        IndexError: If a row with a known label has fewer than three\n",
    "            tab-separated columns.\n",
    "    \"\"\"\n",
    "    def extract_text(s):\n",
    "        # Drop every parenthesis, collapse whitespace runs, trim the ends.\n",
    "        s = s.replace('(', '').replace(')', '')\n",
    "        s = re.sub('\\\\s{2,}', ' ', s)\n",
    "        return s.strip()\n",
    "\n",
    "    label_set = {'entailment': 0, 'contradiction': 1, 'neutral': 2}\n",
    "    file_name = os.path.join(data_dir, 'snli_1.0_train.txt'\n",
    "                             if is_train else 'snli_1.0_test.txt')\n",
    "    premises, hypotheses, labels = [], [], []\n",
    "    # Decode as UTF-8 explicitly so parsing does not depend on the\n",
    "    # platform's default encoding.\n",
    "    with open(file_name, 'r', encoding='utf-8') as f:\n",
    "        next(f, None)  # Skip the first line (the column header).\n",
    "        # Stream the file and filter once instead of materialising every\n",
    "        # row and scanning the list three times.\n",
    "        for line in f:\n",
    "            row = line.split('\\t')\n",
    "            if row[0] not in label_set:\n",
    "                continue\n",
    "            premises.append(extract_text(row[1]))\n",
    "            hypotheses.append(extract_text(row[2]))\n",
    "            labels.append(label_set[row[0]])\n",
    "    return premises, hypotheses, labels"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "打印前3对前提和假设，以及它们的标签（“0”、“1”和“2”分别对应“蕴含”、“矛盾”和“中性”）"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "origin_pos": 6,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "premise: A person on a horse jumps over a broken down airplane .\n",
      "hypothesis: A person is training his horse for a competition .\n",
      "label: 2\n",
      "premise: A person on a horse jumps over a broken down airplane .\n",
      "hypothesis: A person is at a diner , ordering an omelette .\n",
      "label: 1\n",
      "premise: A person on a horse jumps over a broken down airplane .\n",
      "hypothesis: A person is outdoors , on a horse .\n",
      "label: 0\n"
     ]
    }
   ],
   "source": [
    "train_data = read_snli(data_dir, is_train=True)\n",
    "# Slice each of the three parallel lists to its first three entries, then\n",
    "# walk the slices in lockstep.\n",
    "first_three = (column[:3] for column in train_data)\n",
    "for premise, hypothesis, label in zip(*first_three):\n",
    "    print('premise:', premise)\n",
    "    print('hypothesis:', hypothesis)\n",
    "    print('label:', label)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "“蕴含（entailment）”、“矛盾（contradiction）”和“中性（neutral）”三个标签都是平衡的"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "origin_pos": 8,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[183416, 183187, 182764]\n",
      "[3368, 3237, 3219]\n"
     ]
    }
   ],
   "source": [
    "test_data = read_snli(data_dir, is_train=False)\n",
    "# Count how often each label id (0, 1, 2) occurs in each split. The labels\n",
    "# are already a list, so count on it directly rather than copying it once\n",
    "# per label id.\n",
    "for data in [train_data, test_data]:\n",
    "    print([data[2].count(i) for i in range(3)])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "定义用于加载数据集的类"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "origin_pos": 11,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "class SNLIDataset(torch.utils.data.Dataset):\n",
    "    \"\"\"Dataset of SNLI premise/hypothesis index tensors with their labels.\n",
    "\n",
    "    Each item is `((premise, hypothesis), label)`. Both sequences are\n",
    "    vocabulary indices that went through `d2l.truncate_pad` with length\n",
    "    `num_steps`.\n",
    "    \"\"\"\n",
    "    def __init__(self, dataset, num_steps, vocab=None):\n",
    "        \"\"\"Tokenize, index and pad the raw examples.\n",
    "\n",
    "        Args:\n",
    "            dataset: Indexable triple `(premises, hypotheses, labels)`.\n",
    "            num_steps: Target length handed to `d2l.truncate_pad`.\n",
    "            vocab: Vocabulary to reuse. When `None`, one is built from the\n",
    "                premise and hypothesis tokens with `min_freq=5` and a\n",
    "                reserved `'<pad>'` token.\n",
    "        \"\"\"\n",
    "        self.num_steps = num_steps\n",
    "        premise_tokens = d2l.tokenize(dataset[0])\n",
    "        hypothesis_tokens = d2l.tokenize(dataset[1])\n",
    "        if vocab is not None:\n",
    "            self.vocab = vocab\n",
    "        else:\n",
    "            self.vocab = d2l.Vocab(\n",
    "                premise_tokens + hypothesis_tokens,\n",
    "                min_freq=5, reserved_tokens=['<pad>'])\n",
    "        self.premises = self._pad(premise_tokens)\n",
    "        self.hypotheses = self._pad(hypothesis_tokens)\n",
    "        self.labels = torch.tensor(dataset[2])\n",
    "        num_examples = len(self.premises)\n",
    "        print('read ' + str(num_examples) + ' examples')\n",
    "\n",
    "    def _pad(self, lines):\n",
    "        \"\"\"Turn token lists into one tensor of truncated/padded index rows.\"\"\"\n",
    "        rows = []\n",
    "        for tokens in lines:\n",
    "            indices = self.vocab[tokens]\n",
    "            rows.append(d2l.truncate_pad(\n",
    "                indices, self.num_steps, self.vocab['<pad>']))\n",
    "        return torch.tensor(rows)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        \"\"\"Return `((premise, hypothesis), label)` at position `idx`.\"\"\"\n",
    "        pair = (self.premises[idx], self.hypotheses[idx])\n",
    "        return pair, self.labels[idx]\n",
    "\n",
    "    def __len__(self):\n",
    "        \"\"\"Return the number of premise rows, i.e. the number of examples.\"\"\"\n",
    "        return len(self.premises)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "整合代码：下载SNLI数据集，并返回数据迭代器和词表"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "origin_pos": 16,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "read 549367 examples\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "read 9824 examples\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "18678"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def load_data_snli(batch_size, num_steps=50):\n",
    "    \"\"\"Download the SNLI dataset and return data iterators and vocabulary.\n",
    "\n",
    "    Args:\n",
    "        batch_size: Minibatch size used by both data loaders.\n",
    "        num_steps: Sequence length passed on to `SNLIDataset`.\n",
    "\n",
    "    Returns:\n",
    "        `(train_iter, test_iter, vocab)`: a shuffled loader over the\n",
    "        training split, an unshuffled loader over the test split, and the\n",
    "        vocabulary built from the training split.\n",
    "    \"\"\"\n",
    "    num_workers = d2l.get_dataloader_workers()\n",
    "    data_dir = d2l.download_extract('SNLI')\n",
    "    train_data, test_data = (\n",
    "        read_snli(data_dir, flag) for flag in (True, False))\n",
    "    train_set = SNLIDataset(train_data, num_steps)\n",
    "    # The test split reuses the vocabulary built from the training split\n",
    "    # instead of building its own.\n",
    "    test_set = SNLIDataset(test_data, num_steps, train_set.vocab)\n",
    "\n",
    "    def make_loader(split, shuffle):\n",
    "        # Both loaders differ only in the dataset and the shuffle flag.\n",
    "        return torch.utils.data.DataLoader(\n",
    "            split, batch_size, shuffle=shuffle, num_workers=num_workers)\n",
    "\n",
    "    train_iter = make_loader(train_set, True)\n",
    "    test_iter = make_loader(test_set, False)\n",
    "    return train_iter, test_iter, train_set.vocab\n",
    "\n",
    "# Build both iterators with minibatches of 128 examples and sequences of\n",
    "# length 50, then show the vocabulary size as the cell's output.\n",
    "train_iter, test_iter, vocab = load_data_snli(batch_size=128, num_steps=50)\n",
    "len(vocab)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "origin_pos": 18,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([128, 50])\n",
      "torch.Size([128, 50])\n",
      "torch.Size([128])\n"
     ]
    }
   ],
   "source": [
    "# Inspect only the first minibatch: print the shapes of the two input\n",
    "# tensors and of the label tensor, then stop iterating.\n",
    "for X, Y in train_iter:\n",
    "    for part in (X[0], X[1], Y):\n",
    "        print(part.shape)\n",
    "    break"
   ]
  }
 ],
 "metadata": {
  "celltoolbar": "Slideshow",
  "language_info": {
   "name": "python"
  },
  "rise": {
   "autolaunch": true,
   "enable_chalkboard": true,
   "overlay": "<div class='my-top-right'><img height=80px src='http://d2l.ai/_static/logo-with-text.png'/></div><div class='my-top-left'></div>",
   "scroll": true
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}